#!/usr/bin/env python3
# Copyright lowRISC contributors.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
r"""Convert mako template to Hjson register description
"""
import argparse
import logging as log
import textwrap
from math import ceil, log2
from pathlib import Path

import hjson
from mako.template import Template
from tabulate import tabulate

# Name suffix and size (in bytes, i.e. 64bit) of the digest item that
# validate() appends to every partition with a SW or HW digest.
DIGEST_SUFFIX = "_DIGEST"
DIGEST_SIZE = 8

# Prepended to every generated Markdown table.
TABLE_HEADER_COMMENT = '''<!--
DO NOT EDIT THIS FILE DIRECTLY.
It has been generated with hw/ip/otp_ctrl/util/translate-mmap.py
-->

'''

# All input/output paths below are resolved relative to the directory that
# contains this script (hw/ip/otp_ctrl/util) instead of the current working
# directory, so the script can be invoked from anywhere. The constants stay
# plain strings so existing users of them are unaffected.
_UTIL_DIR = Path(__file__).resolve().parent
_IP_DIR = _UTIL_DIR.parent

# memory map source
MMAP_DEFINITION_FILE = str(_IP_DIR / "data" / "otp_ctrl_mmap.hjson")
# documentation tables to generate
PARTITIONS_TABLE_FILE = str(_IP_DIR / "doc" / "otp_ctrl_partitions.md")
DIGESTS_TABLE_FILE = str(_IP_DIR / "doc" / "otp_ctrl_digests.md")
MMAP_TABLE_FILE = str(_IP_DIR / "doc" / "otp_ctrl_mmap.md")
# code templates to render; each one is rendered to the same directory with
# its trailing ".tpl" suffix removed
TEMPLATES = [
    str(_IP_DIR / "data" / "otp_ctrl.hjson.tpl"),
    str(_IP_DIR / "rtl" / "otp_ctrl_part_pkg.sv.tpl"),
]


def wrapped_docstring():
    '''Re-flow the module docstring into filled paragraphs.

    Blank lines delimit paragraphs (runs of blank lines count as one
    separator). Each line is stripped, and every paragraph is passed
    through textwrap.fill. Paragraphs are joined with one empty line.
    '''
    chunks = [[]]
    for raw_line in __doc__.strip().split('\n'):
        text = raw_line.strip()
        if text:
            chunks[-1].append(text)
        elif chunks[-1]:
            # Close the current paragraph; further blank lines are ignored
            # until some text starts the next one.
            chunks.append([])
    return '\n\n'.join(
        textwrap.fill('\n'.join(chunk)) for chunk in chunks if chunk)


def check_bool(x):
    """Return x as a bool.

    Accepts either a bool, which is returned unchanged, or one of the
    strings "true"/"false" (case-insensitive). Any other value, including
    non-string types such as ints or None, is logged as an error and the
    script exits with status 1. A non-string value no longer escapes as an
    AttributeError.
    """
    if isinstance(x, bool):
        return x
    if isinstance(x, str) and x.lower() in ("true", "false"):
        return x.lower() == "true"
    log.error("{} is not a boolean value.".format(x))
    exit(1)


def check_int(x):
    """Return x as an int.

    Accepts an int, which is returned unchanged, or a string made only of
    decimal digits. bool is rejected even though it is an int subclass: a
    true/false in a size or offset field is a configuration mistake, not
    0/1. Anything else (hex or negative strings, floats, None, ...) is
    logged as an error and the script exits with status 1. Such values
    previously escaped as an AttributeError or were silently accepted.
    """
    if isinstance(x, int) and not isinstance(x, bool):
        return x
    if isinstance(x, str) and x.isdecimal():
        return int(x)
    log.error("{} is not a decimal number".format(x))
    exit(1)


def validate(config):
    """Fill in defaults for, and sanity-check, the OTP memory map in place.

    config is the dict loaded from the memory map Hjson file.

    For each entry in config["partitions"]:
      - missing attributes get default values;
      - "secret", "sw_digest" and "hw_digest" are converted to bool;
      - items without an explicit "offset" are packed back to back,
        starting at the partition's own offset;
      - a digest item is appended at the very end of the partition when it
        has a SW or HW digest.

    Finally config["otp"] gains the derived "size", "addr_width" and
    "byte_addr_width" entries.

    Any inconsistency is logged and the script exits with status 1.
    """
    # End (exclusive byte offset) of the space used so far. Partitions are
    # expected in ascending, non-overlapping order.
    offset = 0
    num_part = 0
    for part in config["partitions"]:
        num_part += 1
        # Defaults
        part.setdefault("offset", offset)
        part.setdefault("name", "unknown_name")
        part.setdefault("variant", "Unbuffered")
        part.setdefault("size", "0")
        part.setdefault("secret", "false")
        part.setdefault("sw_digest", "false")
        part.setdefault("hw_digest", "false")
        part.setdefault("write_lock", "none")
        part.setdefault("read_lock", "none")
        part.setdefault("key_sel", "NoKey")
        # An item-less partition is only warned about below, so it must not
        # crash with a KeyError here.
        part.setdefault("items", [])
        log.info("Partition {} at offset {} with size {}".format(
            part["name"], part["offset"], part["size"]))

        # make sure these are boolean types (simplifies the mako templates)
        part["secret"] = check_bool(part["secret"])
        part["sw_digest"] = check_bool(part["sw_digest"])
        part["hw_digest"] = check_bool(part["hw_digest"])

        # Sanity checks
        if part["variant"] not in ["Unbuffered", "Buffered", "LifeCycle"]:
            log.error("Invalid partition type {}".format(part["variant"]))
            exit(1)

        if part["key_sel"] not in [
                "NoKey", "Secret0Key", "Secret1Key", "Secret2Key"
        ]:
            log.error("Invalid key sel {}".format(part["key_sel"]))
            exit(1)

        if part["secret"] and part["key_sel"] == "NoKey":
            log.error(
                "A secret partition needs a key select value other than NoKey")
            exit(1)

        if part["write_lock"].lower() not in ["digest", "csr", "none"]:
            log.error("Invalid value for write_lock")
            exit(1)

        if part["read_lock"].lower() not in ["digest", "csr", "none"]:
            log.error("Invalid value for read_lock")
            exit(1)

        if part["sw_digest"] and part["hw_digest"]:
            log.error(
                "Partition cannot support both a SW and a HW digest at the same time."
            )
            exit(1)

        if part["variant"] == "Unbuffered" and not part["sw_digest"]:
            log.error(
                "Unbuffered partitions without digest are not supported at the moment."
            )
            exit(1)

        if not part["sw_digest"] and not part["hw_digest"]:
            if part["write_lock"].lower(
            ) == "digest" or part["read_lock"].lower() == "digest":
                log.error(
                    "A partition can only be write/read lockable if it has a hw or sw digest."
                )
                exit(1)

        # offset/size may be given as strings in the Hjson file; use the
        # integer values for all arithmetic below.
        part_offset = check_int(part["offset"])
        part_size = check_int(part["size"])

        if part_offset % 8:
            log.error("Partition offset must be 64bit aligned")
            exit(1)

        if part_size % 8:
            log.error("Partition size must be 64bit aligned")
            exit(1)

        if part_offset < offset:
            log.error("Partition {} at offset {} starts before the end of "
                      "the previous partition (offset {}); partitions must "
                      "be ordered and must not overlap".format(
                          part["name"], part_offset, offset))
            exit(1)

        # Items are packed from the start of this partition. This matters
        # when an explicit partition offset leaves a gap after the previous
        # partition.
        offset = part_offset

        # Loop over items within a partition
        for item in part["items"]:
            item.setdefault("name", "unknown_name")
            item.setdefault("size", "0")
            item.setdefault("isdigest", "false")
            item.setdefault("offset", offset)
            log.info("> Item {} at offset {} with size {}".format(
                item["name"], item["offset"], item["size"]))
            offset += check_int(item["size"])

        # Place digest at the end of a partition.
        if part["sw_digest"] or part["hw_digest"]:
            digest_name = part["name"] + DIGEST_SUFFIX
            digest_offset = part_offset + part_size - DIGEST_SIZE
            part["items"].append({
                "name": digest_name,
                "size": DIGEST_SIZE,
                "offset": digest_offset,
                "isdigest": "True"
            })

            log.info("> Adding digest {} at offset {} with size {}".format(
                digest_name, digest_offset, DIGEST_SIZE))
            # Reserve the digest's space so the capacity check below
            # accounts for it.
            offset += DIGEST_SIZE

        if len(part["items"]) == 0:
            log.warning("Partition does not contain any items.")

        # check offsets and size
        if offset > part_offset + part_size:
            log.error("Not enough space in partition "
                      "{} to accommodate all items. Bytes available "
                      "= {}, bytes requested = {}".format(
                          part["name"], part_size, offset - part_offset))
            exit(1)

        offset = part_offset + part_size

    depth = check_int(config["otp"]["depth"])
    width = check_int(config["otp"]["width"])
    if depth < 1 or width < 1:
        # log2() below is undefined for 0.
        log.error("OTP depth and width must be positive, got depth {} and "
                  "width {}".format(depth, width))
        exit(1)
    otp_size = depth * width
    config["otp"]["size"] = otp_size
    config["otp"]["addr_width"] = ceil(log2(depth))
    config["otp"]["byte_addr_width"] = ceil(log2(otp_size))

    if offset > otp_size:
        # Format eagerly: logging would otherwise apply %-formatting to
        # the '{}' placeholders and fail to emit the message.
        log.error(
            "OTP is not big enough to store all partitions. "
            "Bytes available {}, bytes required {}".format(otp_size, offset))
        exit(1)

    log.info("Total number of partitions: {}".format(num_part))
    log.info("Bytes available in OTP: {}".format(otp_size))
    log.info("Bytes required for partitions: {}".format(offset))


def create_partitions_table(config):
    """Render the partition overview as a Markdown pipe table.

    One row is emitted per entry in config["partitions"], with these
    columns:
      - name;
      - whether the partition is secret;
      - whether it is buffered (the "Buffered" and "LifeCycle" variants);
      - whether it is write/read lockable, and by which mechanism;
      - its description, with all whitespace (including newlines)
        collapsed to single spaces.

    A partition without a "desc" entry, or with a null one, gets an empty
    description instead of aborting the whole table.

    Returns the table text produced by tabulate.
    """
    header = [
        "Partition", "Secret", "Buffered", "WR Lockable", "RD Lockable",
        "Description"
    ]
    table = [header]
    colalign = ("center", ) * len(header)

    for part in config["partitions"]:
        is_secret = "yes" if check_bool(part["secret"]) else "no"
        is_buffered = ("yes" if part["variant"] in ["Buffered", "LifeCycle"]
                       else "no")
        wr_lockable = "no"
        if part["write_lock"].lower() in ["csr", "digest"]:
            wr_lockable = "yes (" + part["write_lock"] + ")"
        rd_lockable = "no"
        if part["read_lock"].lower() in ["csr", "digest"]:
            rd_lockable = "yes (" + part["read_lock"] + ")"
        # Collapse whitespace so multi-line Hjson descriptions fit into a
        # single table cell.
        desc = ' '.join((part.get("desc") or "").split())
        table.append([
            part["name"], is_secret, is_buffered, wr_lockable, rd_lockable,
            desc
        ])

    return tabulate(table,
                    headers="firstrow",
                    tablefmt="pipe",
                    colalign=colalign)


def create_mmap_table(config):
    """Render the OTP memory map as a Markdown pipe table.

    Each item of each partition in config["partitions"] gets one row with
    its name, byte address (hex, at least three digits) and size. The
    partition columns (index, name, size, access granule) are only filled
    in on the first row of each partition. Secret partitions are listed
    with a 64bit access granule, all others with 32bit. Digest items link
    to their register anchor "#Reg_<name>_0".

    A partition without any items still gets one row with blank item
    columns, so every partition and its index appear in the table.

    Returns the table text produced by tabulate.
    """
    header = [
        "Index", "Partition", "Size [B]", "Access Granule", "Item",
        "Byte Address", "Size [B]"
    ]
    table = [header]
    colalign = ("center", ) * len(header)

    for k, part in enumerate(config["partitions"]):
        granule = "64bit" if check_bool(part["secret"]) else "32bit"
        part_cols = [str(k), part["name"], str(part["size"]), granule]

        if not part["items"]:
            # Keep empty partitions visible instead of silently dropping
            # them, which would also leave a gap in the index column.
            table.append(part_cols + ["", "", ""])
            continue

        for j, item in enumerate(part["items"]):
            # Only the first item row repeats the partition columns.
            row = list(part_cols) if j == 0 else ["", "", "", ""]

            if check_bool(item["isdigest"]):
                name = "[{}](#Reg_{}_0)".format(item["name"],
                                                item["name"].lower())
            else:
                name = item["name"]

            row.extend([
                name, "0x{:03X}".format(check_int(item["offset"])),
                str(item["size"])
            ])

            table.append(row)

    return tabulate(table,
                    headers="firstrow",
                    tablefmt="pipe",
                    colalign=colalign)


def create_digests_table(config):
    """Render the list of partition digests as a Markdown pipe table.

    Every partition with a HW or SW digest contributes one row: a link to
    its (first) digest item's register anchor, the partition name, and
    whether the digest is computed by HW. A partition that claims a digest
    but has no item flagged as one is logged as an error and the script
    exits with status 1.
    """
    columns = ["Digest Name", " Affected Partition", "Calculated by HW"]
    rows = [columns]

    for part in config["partitions"]:
        by_hw = check_bool(part["hw_digest"])
        if not (by_hw or check_bool(part["sw_digest"])):
            continue

        digest = next(
            (it for it in part["items"] if check_bool(it["isdigest"])), None)
        if digest is None:
            log.error("Partition with digest does not contain a digest item")
            exit(1)

        link = "[{}](#Reg_{}_0)".format(digest["name"],
                                        digest["name"].lower())
        rows.append([link, part["name"], "yes" if by_hw else "no"])

    return tabulate(rows,
                    headers="firstrow",
                    tablefmt="pipe",
                    colalign=("center", ) * len(columns))


def main():
    """Validate the OTP memory map, then regenerate docs and templates.

    Reads MMAP_DEFINITION_FILE and validates and completes it with
    validate(). It then writes the partition, digest and memory-map
    Markdown tables and renders each mako template in TEMPLATES. Each
    template's output goes to the same directory with the trailing
    suffix (".tpl") removed.

    Every output is fully generated before its file is opened for writing.
    A generator that aborts (e.g. via exit(1)) therefore no longer leaves a
    truncated, empty file behind. All files are read and written as UTF-8,
    independent of the host locale.
    """
    log.basicConfig(level=log.INFO,
                    format="%(asctime)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M")

    parser = argparse.ArgumentParser(
        prog="translate-mmap",
        description=wrapped_docstring(),
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.parse_args()

    # The input file does not need to stay open while outputs are produced.
    with open(MMAP_DEFINITION_FILE, 'r', encoding='utf-8') as infile:
        config = hjson.load(infile)
    validate(config)

    # (output file, table generator) pairs for the documentation tables
    tables = [
        (PARTITIONS_TABLE_FILE, create_partitions_table),
        (DIGESTS_TABLE_FILE, create_digests_table),
        (MMAP_TABLE_FILE, create_mmap_table),
    ]
    for path, make_table in tables:
        text = TABLE_HEADER_COMMENT + make_table(config)
        with open(path, 'w', encoding='utf-8') as outfile:
            outfile.write(text)

    # render all templates, e.g. otp_ctrl.hjson.tpl -> otp_ctrl.hjson
    for template in TEMPLATES:
        tpl_path = Path(template)
        with open(tpl_path, 'r', encoding='utf-8') as tplfile:
            tpl = Template(tplfile.read())
        rendered = tpl.render(config=config)
        with open(tpl_path.with_suffix(''), 'w', encoding='utf-8') as outfile:
            outfile.write(rendered)


if __name__ == "__main__":
    main()
